# -*- coding: utf-8 -*-
"""
Code that trains a decision tree to match legislative sections.
Last updated: April 21, 2023
Author: Karen Simpson and Jeremy Gelman
"""
import pandas as pd
import numpy as np
import time
import os
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split # splitting the data
from sklearn.metrics import precision_score, accuracy_score, recall_score # evaluation metric
from sklearn.metrics import confusion_matrix # evaluation metric

# All relative paths below are resolved from the replication folder.
os.chdir('./replication')

# Hand-coded section pairs; the same sheet supplies both the training and the
# test split further down.
handcode = pd.read_excel("training_testdata.xlsx")

# Names for the positional columns (0..5) that the expanded apply() produces.
first_score_columns = dict(enumerate([
    "first_txt1_tk_2",
    "first_txt2_tk_2",
    "first_shared_keys_2",
    "first_txt1_ad_2",
    "first_txt2_ad_2",
    "first_scope_2",
]))

# NOTE(review): gen_hash_first100 is neither defined nor imported in the
# visible part of this file; it must already be in scope when this runs.
# The literal 2 matches the "_2" column suffix (presumably the hash/shingle
# size -- confirm against the helper's definition).
scores_first = handcode.apply(
    lambda pair: gen_hash_first100(
        pair['clean_txt1'], pair['clean_txt2'], 2, pair['sec1'], pair['sec2']
    ),
    axis='columns',
    result_type='expand',
)

# Append the new score columns, then swap their positional labels for names.
handcode = pd.concat([handcode, scores_first], axis=1)
handcode = handcode.rename(columns=first_score_columns)

# Stems for the six values gen_hash returns per row, in positional order; the
# size passed to gen_hash is appended as a suffix (e.g. "txt1_tk_2").
score_stems = ("txt1_tk", "txt2_tk", "shared_keys", "txt1_ad", "txt2_ad", "scope")

# NOTE(review): gen_hash is neither defined nor imported in the visible part
# of this file; it must already be in scope when this runs.
for size in (2, 3, 4, 5, 10):
    # One row-wise pass per size; 'expand' turns the returned sequence into
    # columns labelled 0..5.
    scores = handcode.apply(
        lambda pair: gen_hash(
            pair['clean_txt1'], pair['clean_txt2'], size, pair['sec1'], pair['sec2']
        ),
        axis='columns',
        result_type='expand',
    )

    # Append first, then rename: the positional labels 0..5 are reused on
    # every pass, so they have to be replaced before the next concat.
    handcode = pd.concat([handcode, scores], axis=1)
    handcode = handcode.rename(
        columns={position: f"{stem}_{size}" for position, stem in enumerate(score_stems)}
    )

# Add block features computed from the two cleaned texts of each pair.
# NOTE(review): blocks is neither defined nor imported in the visible part of
# this file; it must already be in scope when this runs.
block_columns = dict(enumerate([
    "long_block",
    "num_blocks",
    "total_blocklength",
    "ave_blocklength",
    "perblock_txt1",
    "perblock_txt2",
]))

start = time.process_time()
scores = handcode.apply(
    lambda pair: blocks(pair['clean_txt1'], pair['clean_txt2']),
    axis='columns',
    result_type='expand',
)
# Report the CPU time (not wall-clock time) spent on the block pass.
print(time.process_time() - start)

# Append the positional columns 0..5 and give them their final names.
handcode = pd.concat([handcode, scores], axis=1).rename(columns=block_columns)

#########################################################################################################


# Feature matrix: the subset of generated score columns fed to the classifier.
X_var = np.asarray(handcode[['txt1_ad_5', 'txt2_ad_5', 'scope_10', 'num_blocks', 'perblock_txt1', 'perblock_txt2', "first_shared_keys_2", "first_txt1_ad_2", "first_txt2_ad_2"]])

# Target: the hand-coded 'minor_differences' label.
y_var_minor = np.asarray(handcode['minor_differences'])

# Hold out 25% of the pairs for testing; the seed is fixed for reproducibility.
X_train_minor, X_test_minor, y_train_minor, y_test_minor = train_test_split(X_var, y_var_minor, test_size = 0.25, random_state = 4)

# Shallow tree (depth <= 5, at least 5 samples per leaf).
clf_model_minor = DecisionTreeClassifier(random_state=4,max_depth=5, min_samples_leaf=5)
clf_model_minor.fit(X_train_minor,y_train_minor)
y_predict_minor = clf_model_minor.predict(X_test_minor)

# Summary metrics on the held-out set; kept as module-level variables (they
# are not printed here).
precision = precision_score(y_test_minor,y_predict_minor)
accuracy = accuracy_score(y_test_minor,y_predict_minor)
recall = recall_score(y_test_minor, y_predict_minor, average='binary')

# Rows are true labels and columns are predicted labels, both ordered as
# labels=[0, 1]. Hence:
#   [0,0] = true negatives    [0,1] = false positives
#   [1,0] = false negatives   [1,1] = true positives
cm_minor = confusion_matrix(y_test_minor, y_predict_minor, labels = [0,1])
print(cm_minor)
print('\nTrue Positives(TP) = ', cm_minor[1,1])
print('\nTrue Negatives(TN) = ', cm_minor[0,0])
print('\nFalse Positives(FP) = ', cm_minor[0,1])
print('\nFalse Negatives(FN) = ', cm_minor[1,0])